{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# <center> 计算图和数据图可视化</center>\n",
    "\n",
    "\n",
    "## 计算图与数据图概述\n",
    "\n",
    "计算图的生成是通过将模型训练过程中的每个计算节点关联后所构成的，初体验者可以通过查看计算图，掌握整个模型的计算走向结构，数据流以及控制流的信息。对于高阶的使用人员，能够通过计算图验证计算节点的输入输出是否正确，并验证整个计算过程是否符合预期。数据图展示的是数据预处理的过程，在MindInsight可视化面板中可查看数据处理的图，能够更加直观地查看数据预处理的每一个环节，并帮助提升模型性能。\n",
    "\n",
    "接下来我们用一个图片分类的项目来体验计算图与数据图的生成与使用。\n",
    "        \n",
    "## 本次体验的整体流程\n",
    "1. 体验模型的数据选择使用MNIST数据集，MNIST数据集整体数据量比较小，更适合体验使用。\n",
    "\n",
    "2. 初始化一个网络，本次的体验使用LeNet网络。\n",
    "\n",
    "3. 增加可视化功能的使用，并设定只记录计算图与数据图。\n",
    "\n",
    "4. 加载训练数据集并进行训练，训练完成后，查看结果并保存模型文件。\n",
    "\n",
    "5. 启用MindInsight的可视化图界面，进行训练过程的核对。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数据集来源\n",
    "\n",
    "方法一\n",
    "\n",
    "从以下网址下载，并将数据包解压后放在Jupyter的工作目录下。\n",
    "\n",
    "- 训练数据集：{\"<http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz>\",\"<http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz>\"}\n",
    "- 测试数据集：{\"<http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz>\",\"<http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz>\"}\n",
    "\n",
    "可执行下面代码查看Jupyter的工作目录。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 训练数据集放在----`Jupyter工作目录+\\MNIST_Data\\train\\`，此时`train`文件夹内应该包含两个文件，`train-images-idx3-ubyte`和`train-labels-idx1-ubyte` \n",
    "- 测试数据集放在----`Jupyter工作目录+\\MNIST_Data\\test\\`，此时`test`文件夹内应该包含两个文件，`t10k-images-idx3-ubyte`和`t10k-labels-idx1-ubyte`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "方法二\n",
    "\n",
    "直接执行以下代码，会自动进行训练集的下载与解压，但是整个过程根据网络好坏情况会需要花费几分钟时间。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import gzip\n",
    "import urllib.request\n",
    "from urllib.parse import urlparse\n",
    "\n",
    "\n",
    "def unzip_file(gzip_path):\n",
    "    \"\"\"\n",
    "    Unzip a given gzip file.\n",
    "\n",
    "    Args:\n",
    "        gzip_path (str): The gzip file path\n",
    "    \"\"\"\n",
    "    open_file = open(gzip_path.replace('.gz', ''), 'wb')\n",
    "    gz_file = gzip.GzipFile(gzip_path)\n",
    "    open_file.write(gz_file.read())\n",
    "    gz_file.close()\n",
    "\n",
    "\n",
    "def download_dataset():\n",
    "    \"\"\"Download the dataset from http://yann.lecun.com/exdb/mnist/.\"\"\"\n",
    "    print(\"******Downloading the MNIST dataset******\")\n",
    "    train_path = \"./MNIST_Data/train/\"\n",
    "    test_path = \"./MNIST_Data/test/\"\n",
    "    train_path_check = os.path.exists(train_path)\n",
    "    test_path_check = os.path.exists(test_path)\n",
    "    if not train_path_check and not test_path_check:\n",
    "        os.makedirs(train_path)\n",
    "        os.makedirs(test_path)\n",
    "    train_url = {\"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\",\n",
    "                 \"http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\"}\n",
    "    test_url = {\"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\",\n",
    "                \"http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\"}\n",
    "\n",
    "    for url in train_url:\n",
    "        url_parse = urlparse(url)\n",
    "        # split the file name from url\n",
    "        file_name = os.path.join(train_path, url_parse.path.split('/')[-1])\n",
    "        if not os.path.exists(file_name.replace('.gz', '')) and not os.path.exists(file_name):\n",
    "            urllib.request.urlretrieve(url, file_name)\n",
    "        unzip_file(file_name)\n",
    "\n",
    "    for url in test_url:\n",
    "        url_parse = urlparse(url)\n",
    "        # split the file name from url\n",
    "        file_name = os.path.join(test_path, url_parse.path.split('/')[-1])\n",
    "        if not os.path.exists(file_name.replace('.gz', '')) and not os.path.exists(file_name):\n",
    "            urllib.request.urlretrieve(url, file_name)\n",
    "        unzip_file(file_name)\n",
    "\n",
    "download_dataset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 数据增强\n",
    "对数据集进行数据增强操作，可以提升模型精度。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.dataset as ds\n",
    "import mindspore.dataset.transforms.vision.c_transforms as CV\n",
    "import mindspore.dataset.transforms.c_transforms as C\n",
    "from mindspore.dataset.transforms.vision import Inter\n",
    "from mindspore.common import dtype as mstype\n",
    "\n",
    "\n",
    "def create_dataset(data_path, batch_size=32, repeat_size=1,\n",
    "                   num_parallel_workers=1):\n",
    "    \"\"\"\n",
    "    Create dataset for train or test.\n",
    "\n",
    "    Args:\n",
    "        data_path (str): The absolute path of the dataset\n",
    "        batch_size (int): The number of data records in each group\n",
    "        repeat_size (int): The number of replicated data records\n",
    "        num_parallel_workers (int): The number of parallel workers\n",
    "    \"\"\"\n",
    "    # define dataset\n",
    "    mnist_ds = ds.MnistDataset(data_path)\n",
    "\n",
    "    # define some parameters needed for data enhancement and rough justification\n",
    "    resize_height, resize_width = 32, 32\n",
    "    rescale = 1.0 / 255.0\n",
    "    shift = 0.0\n",
    "    rescale_nml = 1 / 0.3081\n",
    "    shift_nml = -1 * 0.1307 / 0.3081\n",
    "\n",
    "    # according to the parameters, generate the corresponding data enhancement method\n",
    "    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)\n",
    "    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)\n",
    "    rescale_op = CV.Rescale(rescale, shift)\n",
    "    hwc2chw_op = CV.HWC2CHW()\n",
    "    type_cast_op = C.TypeCast(mstype.int32)\n",
    "\n",
    "    # using map method to apply operations to a dataset\n",
    "    mnist_ds = mnist_ds.map(input_columns=\"label\", operations=type_cast_op, num_parallel_workers=num_parallel_workers)\n",
    "    mnist_ds = mnist_ds.map(input_columns=\"image\", operations=resize_op, num_parallel_workers=num_parallel_workers)\n",
    "    mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_op, num_parallel_workers=num_parallel_workers)\n",
    "    mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)\n",
    "    mnist_ds = mnist_ds.map(input_columns=\"image\", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)\n",
    "    \n",
    "    # process the generated dataset\n",
    "    buffer_size = 10000\n",
    "    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script\n",
    "    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n",
    "    mnist_ds = mnist_ds.repeat(repeat_size)\n",
    "\n",
    "    return mnist_ds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 可视化操作流程\n",
    "\n",
    "1. 准备训练脚本，在训练脚本中指定计算图的超参数信息，使用`Summary`保存到日志中，接着再运行训练脚本。\n",
    "\n",
    "2. 启动MindInsight，启动成功后，就可以通过访问命令执行后显示的地址，查看可视化界面。\n",
    "\n",
    "3. 访问可视化地址成功后，就可以对图界面进行查询等操作。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 初始化网络\n",
    "\n",
    "1. 导入构建网络所使用的模块。\n",
    "\n",
    "2. 构建初始化参数的函数。\n",
    "\n",
    "3. 创建网络，在网络中设置参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.nn as nn\n",
    "from mindspore.common.initializer import TruncatedNormal\n",
    "\n",
    "\n",
    "def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):\n",
    "    \"\"\"weight initial for conv layer\"\"\"\n",
    "    weight = weight_variable()\n",
    "    return nn.Conv2d(in_channels, out_channels,\n",
    "                     kernel_size=kernel_size, stride=stride, padding=padding,\n",
    "                     weight_init=weight, has_bias=False, pad_mode=\"valid\")\n",
    "\n",
    "\n",
    "def fc_with_initialize(input_channels, out_channels):\n",
    "    \"\"\"weight initial for fc layer\"\"\"\n",
    "    weight = weight_variable()\n",
    "    bias = weight_variable()\n",
    "    return nn.Dense(input_channels, out_channels, weight, bias)\n",
    "\n",
    "\n",
    "def weight_variable():\n",
    "    \"\"\"weight initial\"\"\"\n",
    "    return TruncatedNormal(0.02)\n",
    "\n",
    "\n",
    "class LeNet5(nn.Cell):\n",
    "    \n",
    "    def __init__(self, num_class=10, channel=1):\n",
    "        super(LeNet5, self).__init__()\n",
    "        self.num_class = num_class\n",
    "        self.conv1 = conv(channel, 6, 5)\n",
    "        self.conv2 = conv(6, 16, 5)\n",
    "        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)\n",
    "        self.fc2 = fc_with_initialize(120, 84)\n",
    "        self.fc3 = fc_with_initialize(84, self.num_class)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.flatten = nn.Flatten()\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.max_pool2d(x)\n",
    "        x = self.conv2(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.max_pool2d(x)\n",
    "        x = self.flatten(x)\n",
    "        x = self.fc1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 执行训练\n",
    "\n",
    "1. 导入所需的代码包，并示例化训练网络。\n",
    "2. 通过MindSpore提供的 `SummaryCollector` 接口，实现收集计算图和数据图。在实例化 `SummaryCollector` 时，在 `collect_specified_data` 参数中，通过设置 `collect_graph` 指定收集计算图，设置 `collect_dataset_graph` 指定收集数据图。\n",
    "\n",
    "更多 `SummaryCollector` 的用法，请点击[API文档](https://www.mindspore.cn/api/zh-CN/master/api/python/mindspore/mindspore.train.html?highlight=summarycollector#mindspore.train.callback.SummaryCollector)查看。\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.nn as nn\n",
    "from mindspore import context\n",
    "from mindspore.train import Model\n",
    "from mindspore.nn.metrics import Accuracy\n",
    "from mindspore.train.callback import LossMonitor, SummaryCollector\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    device_target = \"CPU\"\n",
    "    \n",
    "    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)\n",
    "    download_dataset()\n",
    "    ds_train = create_dataset(data_path=\"./MNIST_Data/train/\")\n",
    "\n",
    "    network = LeNet5()\n",
    "    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction=\"mean\")\n",
    "    net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)\n",
    "    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())\n",
    "    model = Model(network, net_loss, net_opt, metrics={\"Accuracy\": Accuracy()})\n",
    "\n",
    "    specified={'collect_graph': True, 'collect_dataset_graph': True}\n",
    "    summary_collector = SummaryCollector(summary_dir='./summary_dir', collect_specified_data=specified, collect_freq=1, keep_default_action=False)\n",
    "    \n",
    "    print(\"============== Starting Training ==============\")\n",
    "    model.train(epoch=2, train_dataset=ds_train, callbacks=[LossMonitor(), summary_collector], dataset_sink_mode=False)\n",
    "\n",
    "    print(\"============== Starting Testing ==============\")\n",
    "    ds_eval = create_dataset(\"./MNIST_Data/test/\")\n",
    "    acc = model.eval(ds_eval, dataset_sink_mode=False)\n",
    "    print(\"============== {} ==============\".format(acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 启动MindInsight\n",
    "- 启动MindInsigh服务命令：`mindinsigh start --summary-base-dir=/path/ --port=8080`；\n",
    "- 执行完服务命令后，访问给出的地址，查看MindInsigh可视化结果。\n",
    "\n",
    "> 其中 /path/ 为 `SummaryCollector` 中参数 `summary_dir` 所指定的目录。\n",
    "\n",
    "![title](https://gitee.com/mindspore/docs/raw/master/tutorials/notebook/mindinsight/images/mindinsight_map.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 计算图信息\n",
    "- 文本选择框：输入计算图对应的路径及文件名，显示相应的计算图，便于查找文件。\n",
    "- 搜索框：可以对整体计算图的节点信息进行搜索，输入完整的节点名称，回车执行搜索，如果有该名称节点，就会呈现出来，便于查找节点。\n",
    "- 缩略图：展示整体计算图的缩略情况，在面板左边查看详细图结构时，在缩略图处会有定位，显示当前查看的位置在整体计算图中的定位，实时呈现部分与整体的关系。\n",
    "- 节点信息：显示当前所查看节点的信息，包括名称、类型、属性、输入和输出。便于在训练结束后，核对计算正确性时查看。\n",
    "- 图例：图例中包括命名空间、聚合节点、虚拟节点、算子节点、常量节点，通过不同图形来区分。\n",
    "\n",
    "![title](https://gitee.com/mindspore/docs/raw/master/tutorials/notebook/mindinsight/images/cast_map.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数据图展示\n",
    "\n",
    "数据图展示了数据增强中对数据进行操作的流程。\n",
    "\n",
    "1. 首先是从加载数据集 `mnist_ds = ds.MnistDataset(data_path)` 开始，对应数据图中 `MnistDataset`。\n",
    "\n",
    "2. 下面代码为上面的 `create_dataset` 函数中作数据预处理与数据增强的相关操作。可以从数据图中清晰地看到数据处理的流程。通过查看数据图，可以帮助分析是否存在不恰当的数据处理流程。\n",
    "\n",
    "```\n",
    "mnist_ds = mnist_ds.map(input_columns=\"label\", operations=type_cast_op, num_parallel_workers=num_parallel_workers)\n",
    "mnist_ds = mnist_ds.map(input_columns=\"image\", operations=resize_op, num_parallel_workers=num_parallel_workers)\n",
    "mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_op, num_parallel_workers=num_parallel_workers)\n",
    "mnist_ds = mnist_ds.map(input_columns=\"image\", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)\n",
    "mnist_ds = mnist_ds.map(input_columns=\"image\", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)\n",
    "\n",
    "mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script\n",
    "mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n",
    "mnist_ds = mnist_ds.repeat(repeat_size)\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![title](https://gitee.com/mindspore/docs/raw/master/tutorials/notebook/mindinsight/images/data_map.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 关闭MindInsight\n",
    "\n",
    "- 查看完成后，在命令行中可执行此命令 `mindinsight stop --port=8080`，关闭MindInsight。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}